Building a car classifier

Continuing as we were...

The plan going in:

  • Try some ways to deal with our class imbalance
  • Try transfer learning by fine-tuning squeezenet or VGG16
  • See how much we can gain via data augmentation
  • Perhaps go back and redo our pre-processing to use bounding boxes.
In [1]:
%load_ext autoreload
%autoreload 2
In [2]:
# system
import os
import glob
import itertools as it
import operator
from collections import defaultdict
from StringIO import StringIO

# other libraries
import cPickle as pickle
import numpy as np 
import pandas as pd
import scipy.io  # for loading .mat files
import scipy.misc # for imresize
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from PIL import Image
import seaborn as sns
import requests
In [150]:
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Convolution2D, MaxPooling2D
from keras.layers import Input, GlobalAveragePooling2D
from keras.utils import np_utils

# https://github.com/fchollet/keras/issues/4499
from keras.layers.core import K
from keras.callbacks import TensorBoard

# for name scopes to make TensorBoard look prettier (doesn't work well yet)
import tensorflow as tf 
In [4]:
# my code
from display import (visualize_keras_model, plot_training_curves,
                     plot_confusion_matrix)
from helpers import combine_histories
In [5]:
%matplotlib inline
sns.set_style("white")
p = sns.color_palette()

# repeatability:
np.random.seed(42)
In [6]:
data_root = os.path.expanduser("~/data/cars")

Load saved metadata

(the images are too big when stored raw -- faster to just reload from jpg)

In [7]:
from collections import namedtuple
Example = namedtuple('Example',
                     ['rel_path', 'x1', 'y1', 'x2','y2','cls','test'])
In [18]:
# Load data we saved in 10-cars.ipynb
with open('class_details.pkl') as f:
    loaded = pickle.load(f)
    macro_classes = loaded['macro_classes']
    macro_class_map = loaded['macro_class_map']
    cls_tuples = loaded['cls_tuples']
    classes = loaded['classes']
    examples = loaded['examples']
    by_class = loaded['by_class']
    by_car_type = loaded['by_car_type']

macro_class_map
Out[18]:
{u'Convertible': 0,
 u'Coupe': 1,
 u'Minivan': 2,
 u'Pickup': 3,
 u'SUV': 4,
 u'Sedan': 5,
 u'Van': 6,
 u'Wagon': 7}
In [9]:
resized_path = os.path.join(data_root,'resized_car_ims') 

Load the data again.

In [10]:
def gray_to_rgb(im):
    """
    Noticed (due to array projection error in code below) that there is at least
    one grayscale image in the dataset.
    We'll use this to convert.
    """
    w, h = im.shape
    ret = np.empty((w,h,3), dtype=np.uint8)
    ret[:,:,0] = im
    ret[:,:,1] = im
    ret[:,:,2] = im
    return ret
In [11]:
def load_examples(by_class, cls, limit=None):
    """
    Load examples for a class. Ignores test/train distinction -- 
    we'll do our own train/validation/test split later.
    
    Args:
        by_class: our above dict -- class_id -> [Example()]
        cls: which class to load
        limit: if not None, only load this many images.
        
    Returns:
        list of (X,y) tuples, one for each image.
            X: 3x227x227 ndarray of type uint8
            Y: class_id (will be equal to cls)
    """
    res = []
    to_load = by_class[cls]
    if limit:
        to_load = to_load[:limit]

    for ex in to_load:
        # load the resized image!
        img_path = os.path.join(data_root, 
                        ex.rel_path.replace('car_ims', 'resized_car_ims'))
        img = mpimg.imread(img_path)
        # handle any grayscale images
        if len(img.shape) == 2:
            img = gray_to_rgb(img)
        res.append((img, cls))
    return res
In [12]:
def split_examples(xs, valid_frac, test_frac):
    """
    Randomly splits the xs array into train, valid, test, with specified 
    percentages. Rounds down.
    
    Returns:
        (train, valid, test)
    """
    assert valid_frac + test_frac < 1
    
    n = len(xs)
    valid = int(valid_frac * n)
    test = int(test_frac * n)
    train = n - valid - test
    
    # don't change passed-in list
    shuffled = xs[:]
    np.random.shuffle(shuffled)

    return (shuffled[:train], 
            shuffled[train:train + valid], 
            shuffled[train + valid:])

# quick test
split_examples(range(10), 0.2, 0.4)
Out[12]:
([8, 1, 5, 0], [7, 2], [9, 4, 3, 6])
In [14]:
# Look at training data -- there's so little we can look at all of it

def plot_data(xs, ys, predicts):
    """Plot the images in xs, with corresponding correct labels
    and predictions.
    
    Args:
        xs: RGB or grayscale images with float32 values in [0,1].
        ys: one-hot encoded labels
        predicts: probability vectors (same dim as ys, normalized e.g. via softmax)
    """
    
    # sort all 3 by ys
    xs, ys, ps = zip(*sorted(zip(xs, ys, predicts), 
                             key=lambda tpl: tpl[1][0]))
    n = len(xs)
    rows = (n+9)/10
    fig, plots = plt.subplots(rows,10, sharex='all', sharey='all',
                             figsize=(20,2*rows), squeeze=False)
    for i in range(n):
        # read the image
        ax = plots[i // 10, i % 10]
        ax.axis('off')
        img = xs[i].reshape(227,227,-1) 

        if img.shape[-1] == 1: # Grayscale
            # Get rid of the unneeded dimension
            img = img.squeeze()
            # flip grayscale:
            img = 1-img 
            
        ax.imshow(img)
        # dot with one-hot vector picks out right element
        pcorrect = np.dot(ps[i], ys[i]) 
        if pcorrect > 0.8:
            color = "blue"
        else:
            color = "red"
        ax.set_title("{}   p={:.2f}".format(int(ys[i][0]), pcorrect),
                     loc='center', fontsize=18, color=color)
    return fig
In [15]:
# normalize the data, this time leaving it in color
def normalize_for_cnn(xs):
    ret = (xs / 255.0)
    return ret
In [16]:
def image_from_url(url):
    response = requests.get(url)
    img = Image.open(StringIO(response.content))
    return img
In [19]:
# Load images
IMG_PER_CAR = None # 20 # None to use all
valid_frac = 0.2
test_frac = 0.2

train = []
valid = []
test = []
for car_type, model_tuples in by_car_type.items():
    macro_class_id = macro_class_map[car_type]
    
    for model_tpl in model_tuples:
        cls = model_tpl[0]
        examples = load_examples(by_class, cls, limit=IMG_PER_CAR)
        # replace class labels with the id of the macro class
        examples = [(X, macro_class_id) for (X,y) in examples]
        # split each class separately, so all have same fractions of 
        # train/valid/test
        (cls_train, cls_valid, cls_test) = split_examples(
            examples,
            valid_frac, test_frac)
        # and add them to the overall train/valid/test sets
        train.extend(cls_train)
        valid.extend(cls_valid)
        test.extend(cls_test)

# ...and shuffle to make training work better.
np.random.shuffle(train)
np.random.shuffle(valid)
np.random.shuffle(test)
In [20]:
# We have lists of (X,Y) tuples. Let's unzip into lists of Xs and Ys.
X_train, Y_train = zip(*train)
X_valid, Y_valid = zip(*valid)
X_test, Y_test = zip(*test)

# and turn into np arrays of the right dimension.
def convert_X(xs):
    '''
    Take list of (w,h,3) images.
    Turn into an np array, change type to float32.
    '''
    return np.array(xs).astype('float32')
    
X_train = convert_X(X_train)
X_valid = convert_X(X_valid)
X_test = convert_X(X_test)
In [21]:
X_train.shape
Out[21]:
(9867, 227, 227, 3)
In [22]:
def convert_Y(ys, macro_classes):
    '''
    Convert to np array, make one-hot.
    Already ensured they're sequential from zero.
    '''
    n_classes = len(macro_classes)
    return np_utils.to_categorical(ys, n_classes)

Y_train = convert_Y(Y_train, macro_classes)
Y_valid = convert_Y(Y_valid, macro_classes)
Y_test = convert_Y(Y_test, macro_classes)
In [23]:
Y_train.shape
Out[23]:
(9867, 8)
In [24]:
# normalize the data, this time leaving it in color
X_train_norm = normalize_for_cnn(X_train)
X_valid_norm = normalize_for_cnn(X_valid)
X_test_norm = normalize_for_cnn(X_test)
In [25]:
# Let's use more or less the same model to start (num classes changes)
def cnn_model2(use_dropout=True):
    model = Sequential()
    nb_filters = 16
    pool_size = (2,2)
    filter_size = 3
    nb_classes = len(macro_classes)
    
    with tf.name_scope("conv1") as scope:
        model.add(Convolution2D(nb_filters, filter_size, 
                            input_shape=(227, 227, 3)))
        model.add(Activation('relu'))
        model.add(MaxPooling2D(pool_size=pool_size))
        if use_dropout:
            model.add(Dropout(0.5))

    with tf.name_scope("conv2") as scope:
        model.add(Convolution2D(nb_filters, filter_size))
        model.add(Activation('relu'))
        model.add(MaxPooling2D(pool_size=pool_size))
        if use_dropout:
            model.add(Dropout(0.5))

    with tf.name_scope("conv3") as scope:
        model.add(Convolution2D(nb_filters, filter_size))
        model.add(Activation('relu'))
        model.add(MaxPooling2D(pool_size=pool_size))
        if use_dropout:
            model.add(Dropout(0.5))

    with tf.name_scope("dense1") as scope:
        model.add(Flatten())
        model.add(Dense(16))
        model.add(Activation('relu'))
        if use_dropout:
            model.add(Dropout(0.5))

    with tf.name_scope("softmax") as scope:
        model.add(Dense(nb_classes))
        model.add(Activation('softmax'))
    return model

# Uncomment if getting a "Invalid argument: You must feed a value
# for placeholder tensor ..." when rerunning training. 
# K.clear_session() # https://github.com/fchollet/keras/issues/4499
    

model3 = cnn_model2()
model3.compile(loss='categorical_crossentropy',
              optimizer='adadelta',
              metrics=['accuracy'])
In [26]:
# This model will train slowly, so let's checkpoint it periodically
from keras.callbacks import ModelCheckpoint
In [28]:
recompute = False

if recompute:
#     # Save info during computation so we can see what's happening
#     tbCallback = TensorBoard(
#         log_dir='./graph', histogram_freq=1, 
#         write_graph=False, write_images=False)

    checkpoint = ModelCheckpoint('macro_class_cnn_checkpoint.5',
                                 monitor='val_acc',
                                 verbose=1,
                                 save_best_only=True, mode='max',
                                 save_weights_only=True)

    # Fit the model! Using a bigger batch size and fewer epochs
    # because we have ~10K training images now instead of 100.
    history = model3.fit(
        X_train_norm, Y_train,
        batch_size=64, nb_epoch=50, verbose=1,
        validation_data=(X_valid_norm, Y_valid),
        callbacks=[checkpoint]
    )
else:
    model3.load_weights('macro_class_cnn.5')
In [29]:
# change to True to save
if False:
    model3.save('macro_class_cnn.h5')

Diagnosing what's going wrong...

As we saw before, the model is starting to overfit. Let's try to diagnose what's going on, then decide what to do. Let's start by looking at the confusion matrices again.

In [30]:
# Get the predictions
predict_train = model3.predict(X_train_norm)
predict_valid = model3.predict(X_valid_norm)
predict_test = model3.predict(X_test_norm)
In [31]:
plot_confusion_matrix(Y_test, predict_test, macro_classes,
                      normalize=False,
                      title="Test confusion matrix");
Confusion matrix, without normalization
[[ 24  95   0   1  49 232   0   0]
 [  7 233   0   6  35 265   0   0]
 [  0   7   0   2  36  52   0   0]
 [  0  10   0 117 109  61   0   0]
 [  3  13   0  26 363 182   0   0]
 [  6  51   0   8 103 636   0   0]
 [  2   2   0   5  50  54   0   0]
 [  1  44   0   6  56 207   0   0]]
In [32]:
plot_confusion_matrix(Y_train, predict_train, macro_classes,
                      title="Train confusion matrix");
Confusion matrix, without normalization
[[  56  283    0    9  151  757    0    0]
 [  14  723    0   14  138  820    0    0]
 [   0   12    0    3   66  223    0    0]
 [   0   23    0  333  377  187    0    0]
 [   0   40    0   76 1198  520    0    0]
 [   5  166    0   18  333 1992    0    0]
 [   0   20    0   18  149  165    0    0]
 [   3  126    0   16  151  682    0    0]]

Note: normalized confusion matrices can be helpful...

In [39]:
# Normalized to see per-class behavior better
plot_confusion_matrix(Y_train, predict_train, macro_classes,                      
                      title="Train confusion matrix", normalize=True);
Normalized confusion matrix

Well, it seems that most car types are classified as sedan. Not too surprising, especially given that sedans are overrepresented. It's starting to learn that SUVs and pickups are different from sedans, and occasionally manages to distinguish coupes from sedans.

So far, it doesn't use minivan, van, or wagon labels at all.

Things to check / do:

  • Is it correctly getting coupe images from the side, incorrectly from the front or back?
  • Look at class probabilities for some images, not just the maximal one
  • Count how many training images we have for each class. May want to oversample the low prob classes.
  • Try fine-tuning an off-the-shelf model.
In [43]:
# What's our class balance
xs, counts = np.unique(np.argmax(Y_train, axis=1),return_counts=True)
plt.bar(xs, counts, tick_label=macro_classes, align='center')
Out[43]:
<Container object of 8 artists>

Let's look at some mistakes on the train set

In [44]:
predict_train_labels = np.argmax(predict_train, axis=1)
correct_labels = np.argmax(Y_train, axis=1)
In [62]:
correct_train = np.where(predict_train_labels==correct_labels)[0]
wrong_train = np.where(predict_train_labels!=correct_labels)[0]
percent = 100 * len(correct_train)/float(len(correct_labels))
print("Training: {:.2f}% correct".format(percent))
Training: 43.60% correct
In [64]:
n_to_view = 20
subset = np.random.choice(correct_train, n_to_view, replace=False)
fig = plot_data(X_train_norm[subset], Y_train[subset], predict_train[subset]);
fig.suptitle("Correct predictions")
Out[64]:
<matplotlib.text.Text at 0x1bc1c8410>

Note that plot_data uses red whenever the probability of the correct class is <0.8. This doesn't make much sense with 8 classes -- would be nice to change it. These were all in fact "correct" -- true class had maximum prob.

In [65]:
n_to_view = 20
subset = np.random.choice(wrong_train, n_to_view, replace=False)
fig = plot_data(X_train_norm[subset], Y_train[subset], predict_train[subset]);
fig.suptitle("Wrong predictions")
Out[65]:
<matplotlib.text.Text at 0x1bf962d50>

Zooming in on the coupe class in particular

Is there a pattern for when it's confusing coupes and sedans?

In [71]:
correct_coupe = np.where((predict_train_labels==correct_labels) & (correct_labels==macro_class_map['Coupe']))[0]
wrong_coupe = np.where((predict_train_labels!=correct_labels) & (correct_labels==macro_class_map['Coupe']))[0]

n_to_view = 40
subset = np.random.choice(correct_coupe, n_to_view, replace=False)
fig = plot_data(X_train_norm[subset], Y_train[subset], predict_train[subset]);
fig.suptitle("Correct coupe predictions")

subset = np.random.choice(wrong_coupe, n_to_view, replace=False)
fig = plot_data(X_train_norm[subset], Y_train[subset], predict_train[subset]);
fig.suptitle("Wrong coupe predictions", fontsize=18)
Out[71]:
<matplotlib.text.Text at 0x1c73a9ad0>

Should make some functions to make this kind of analysis easier...

I don't see a clear pattern--side and front and back views in both sets of labels. One option is to keep training our own network, then come back and fight overfitting somehow. Instead, lets try transfer learning using squeezenet.

Try transfer learning

Let's try to use a model that's been trained on the imagenet dataset (1 million images!) That's generally a better place to start for this kind of vision problem than training from scratch on a small dataset.

First, get SqueezeNet set up

https://github.com/rcmalli/keras-squeezenet

In [72]:
!pip install keras_squeezenet
Collecting keras_squeezenet
  Downloading keras_squeezenet-0.3.tar.gz
Requirement already satisfied: numpy in /Users/shnayder/anaconda/lib/python2.7/site-packages (from keras_squeezenet)
Requirement already satisfied: pillow in /Users/shnayder/anaconda/lib/python2.7/site-packages (from keras_squeezenet)
Requirement already satisfied: tensorflow in /Users/shnayder/anaconda/lib/python2.7/site-packages (from keras_squeezenet)
Requirement already satisfied: keras in /Users/shnayder/anaconda/lib/python2.7/site-packages (from keras_squeezenet)
Requirement already satisfied: olefile in /Users/shnayder/anaconda/lib/python2.7/site-packages (from pillow->keras_squeezenet)
Requirement already satisfied: mock>=2.0.0 in /Users/shnayder/anaconda/lib/python2.7/site-packages (from tensorflow->keras_squeezenet)
Requirement already satisfied: six>=1.10.0 in /Users/shnayder/anaconda/lib/python2.7/site-packages (from tensorflow->keras_squeezenet)
Requirement already satisfied: protobuf>=3.1.0 in /Users/shnayder/anaconda/lib/python2.7/site-packages (from tensorflow->keras_squeezenet)
Requirement already satisfied: wheel in /Users/shnayder/anaconda/lib/python2.7/site-packages (from tensorflow->keras_squeezenet)
Requirement already satisfied: theano in /Users/shnayder/anaconda/lib/python2.7/site-packages (from keras->keras_squeezenet)
Requirement already satisfied: pyyaml in /Users/shnayder/anaconda/lib/python2.7/site-packages (from keras->keras_squeezenet)
Requirement already satisfied: funcsigs>=1; python_version < "3.3" in /Users/shnayder/anaconda/lib/python2.7/site-packages (from mock>=2.0.0->tensorflow->keras_squeezenet)
Requirement already satisfied: pbr>=0.11 in /Users/shnayder/anaconda/lib/python2.7/site-packages (from mock>=2.0.0->tensorflow->keras_squeezenet)
Requirement already satisfied: setuptools in /Users/shnayder/anaconda/lib/python2.7/site-packages (from protobuf>=3.1.0->tensorflow->keras_squeezenet)
Requirement already satisfied: scipy>=0.11 in /Users/shnayder/anaconda/lib/python2.7/site-packages (from theano->keras->keras_squeezenet)
Requirement already satisfied: appdirs>=1.4.0 in /Users/shnayder/anaconda/lib/python2.7/site-packages (from setuptools->protobuf>=3.1.0->tensorflow->keras_squeezenet)
Requirement already satisfied: packaging>=16.8 in /Users/shnayder/anaconda/lib/python2.7/site-packages (from setuptools->protobuf>=3.1.0->tensorflow->keras_squeezenet)
Requirement already satisfied: pyparsing in /Users/shnayder/anaconda/lib/python2.7/site-packages (from packaging>=16.8->setuptools->protobuf>=3.1.0->tensorflow->keras_squeezenet)
Building wheels for collected packages: keras-squeezenet
  Running setup.py bdist_wheel for keras-squeezenet ... - done
  Stored in directory: /Users/shnayder/Library/Caches/pip/wheels/94/37/36/900c81337d77ce40e20250809226b3690651f50be94d51f186
Successfully built keras-squeezenet
Installing collected packages: keras-squeezenet
Successfully installed keras-squeezenet-0.3
In [73]:
from keras_squeezenet import SqueezeNet
from keras.applications.imagenet_utils import preprocess_input, decode_predictions
from keras.preprocessing import image
In [ ]:
model = SqueezeNet()
In [85]:
# Oops -- screwed up the first training image (see note about preprocess_input below)
X_train[0]
Out[85]:
array([[[-883.11999512, -820.01098633, -700.4510498 ],
        [-883.11999512, -820.01098633, -700.4510498 ],
        [-883.11999512, -820.01098633, -700.4510498 ],
        ..., 
        [-918.11999512, -855.01098633, -735.4510498 ],
        [-918.11999512, -855.01098633, -735.4510498 ],
        [-918.11999512, -855.01098633, -735.4510498 ]],

       [[-882.11999512, -819.01098633, -699.4510498 ],
        [-882.11999512, -819.01098633, -699.4510498 ],
        [-882.11999512, -819.01098633, -699.4510498 ],
        ..., 
        [-916.11999512, -853.01098633, -733.4510498 ],
        [-916.11999512, -853.01098633, -733.4510498 ],
        [-916.11999512, -853.01098633, -733.4510498 ]],

       [[-881.11999512, -818.01098633, -698.4510498 ],
        [-881.11999512, -818.01098633, -698.4510498 ],
        [-881.11999512, -818.01098633, -698.4510498 ],
        ..., 
        [-913.11999512, -850.01098633, -730.4510498 ],
        [-913.11999512, -850.01098633, -730.4510498 ],
        [-913.11999512, -850.01098633, -730.4510498 ]],

       ..., 
       [[-911.11999512, -848.01098633, -730.4510498 ],
        [-911.11999512, -848.01098633, -730.4510498 ],
        [-910.11999512, -847.01098633, -729.4510498 ],
        ..., 
        [-912.11999512, -849.01098633, -729.4510498 ],
        [-912.11999512, -849.01098633, -729.4510498 ],
        [-912.11999512, -849.01098633, -729.4510498 ]],

       [[-911.11999512, -848.01098633, -730.4510498 ],
        [-911.11999512, -848.01098633, -730.4510498 ],
        [-910.11999512, -847.01098633, -729.4510498 ],
        ..., 
        [-912.11999512, -849.01098633, -729.4510498 ],
        [-912.11999512, -849.01098633, -729.4510498 ],
        [-912.11999512, -849.01098633, -729.4510498 ]],

       [[-911.11999512, -848.01098633, -730.4510498 ],
        [-911.11999512, -848.01098633, -730.4510498 ],
        [-910.11999512, -847.01098633, -729.4510498 ],
        ..., 
        [-912.11999512, -849.01098633, -729.4510498 ],
        [-912.11999512, -849.01098633, -729.4510498 ],
        [-912.11999512, -849.01098633, -729.4510498 ]]], dtype=float32)
In [87]:
img = X_train[1]
plt.imshow(img/255.0)
x = img# image.img_to_array(img)
x = np.expand_dims(x, axis=0)
# preprocess_input modifies its argument!
x = preprocess_input(x.copy())

preds = model.predict(x)

print('Predicted:', decode_predicktions(preds))
('Predicted:', [[(u'n04285008', u'sports_car', 0.5435679), (u'n02974003', u'car_wheel', 0.41837999), (u'n03100240', u'convertible', 0.017516876), (u'n02814533', u'beach_wagon', 0.013616056), (u'n03459775', u'grille', 0.0029945327)]])
In [134]:
# screwed up X_train[0] earlier. Rather than rerun, I tried to hack/fix it manually 
# (didn't really work, but it's just one image, so decided not to care). Now it looks funny, and is a good
# reminder to check your data throughout your pipeline, not just once at the beginning...
plt.imshow(X_train[0])
Out[134]:
<matplotlib.image.AxesImage at 0x1e11a3150>

Adapting Squeezenet

Let's replace the output layer with a smaller classifier.

Following some comments at https://github.com/fchollet/keras/issues/2371 and our code in 07-transfer.

In [89]:
model.summary()
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
input_8 (InputLayer)             (None, 227, 227, 3)   0                                            
____________________________________________________________________________________________________
conv1 (Conv2D)                   (None, 113, 113, 64)  1792                                         
____________________________________________________________________________________________________
relu_conv1 (Activation)          (None, 113, 113, 64)  0                                            
____________________________________________________________________________________________________
pool1 (MaxPooling2D)             (None, 56, 56, 64)    0                                            
____________________________________________________________________________________________________
fire2/squeeze1x1 (Conv2D)        (None, 56, 56, 16)    1040                                         
____________________________________________________________________________________________________
fire2/relu_squeeze1x1 (Activatio (None, 56, 56, 16)    0                                            
____________________________________________________________________________________________________
fire2/expand1x1 (Conv2D)         (None, 56, 56, 64)    1088                                         
____________________________________________________________________________________________________
fire2/expand3x3 (Conv2D)         (None, 56, 56, 64)    9280                                         
____________________________________________________________________________________________________
fire2/relu_expand1x1 (Activation (None, 56, 56, 64)    0                                            
____________________________________________________________________________________________________
fire2/relu_expand3x3 (Activation (None, 56, 56, 64)    0                                            
____________________________________________________________________________________________________
fire2/concat (Concatenate)       (None, 56, 56, 128)   0                                            
____________________________________________________________________________________________________
fire3/squeeze1x1 (Conv2D)        (None, 56, 56, 16)    2064                                         
____________________________________________________________________________________________________
fire3/relu_squeeze1x1 (Activatio (None, 56, 56, 16)    0                                            
____________________________________________________________________________________________________
fire3/expand1x1 (Conv2D)         (None, 56, 56, 64)    1088                                         
____________________________________________________________________________________________________
fire3/expand3x3 (Conv2D)         (None, 56, 56, 64)    9280                                         
____________________________________________________________________________________________________
fire3/relu_expand1x1 (Activation (None, 56, 56, 64)    0                                            
____________________________________________________________________________________________________
fire3/relu_expand3x3 (Activation (None, 56, 56, 64)    0                                            
____________________________________________________________________________________________________
fire3/concat (Concatenate)       (None, 56, 56, 128)   0                                            
____________________________________________________________________________________________________
pool3 (MaxPooling2D)             (None, 27, 27, 128)   0                                            
____________________________________________________________________________________________________
fire4/squeeze1x1 (Conv2D)        (None, 27, 27, 32)    4128                                         
____________________________________________________________________________________________________
fire4/relu_squeeze1x1 (Activatio (None, 27, 27, 32)    0                                            
____________________________________________________________________________________________________
fire4/expand1x1 (Conv2D)         (None, 27, 27, 128)   4224                                         
____________________________________________________________________________________________________
fire4/expand3x3 (Conv2D)         (None, 27, 27, 128)   36992                                        
____________________________________________________________________________________________________
fire4/relu_expand1x1 (Activation (None, 27, 27, 128)   0                                            
____________________________________________________________________________________________________
fire4/relu_expand3x3 (Activation (None, 27, 27, 128)   0                                            
____________________________________________________________________________________________________
fire4/concat (Concatenate)       (None, 27, 27, 256)   0                                            
____________________________________________________________________________________________________
fire5/squeeze1x1 (Conv2D)        (None, 27, 27, 32)    8224                                         
____________________________________________________________________________________________________
fire5/relu_squeeze1x1 (Activatio (None, 27, 27, 32)    0                                            
____________________________________________________________________________________________________
fire5/expand1x1 (Conv2D)         (None, 27, 27, 128)   4224                                         
____________________________________________________________________________________________________
fire5/expand3x3 (Conv2D)         (None, 27, 27, 128)   36992                                        
____________________________________________________________________________________________________
fire5/relu_expand1x1 (Activation (None, 27, 27, 128)   0                                            
____________________________________________________________________________________________________
fire5/relu_expand3x3 (Activation (None, 27, 27, 128)   0                                            
____________________________________________________________________________________________________
fire5/concat (Concatenate)       (None, 27, 27, 256)   0                                            
____________________________________________________________________________________________________
pool5 (MaxPooling2D)             (None, 13, 13, 256)   0                                            
____________________________________________________________________________________________________
fire6/squeeze1x1 (Conv2D)        (None, 13, 13, 48)    12336                                        
____________________________________________________________________________________________________
fire6/relu_squeeze1x1 (Activatio (None, 13, 13, 48)    0                                            
____________________________________________________________________________________________________
fire6/expand1x1 (Conv2D)         (None, 13, 13, 192)   9408                                         
____________________________________________________________________________________________________
fire6/expand3x3 (Conv2D)         (None, 13, 13, 192)   83136                                        
____________________________________________________________________________________________________
fire6/relu_expand1x1 (Activation (None, 13, 13, 192)   0                                            
____________________________________________________________________________________________________
fire6/relu_expand3x3 (Activation (None, 13, 13, 192)   0                                            
____________________________________________________________________________________________________
fire6/concat (Concatenate)       (None, 13, 13, 384)   0                                            
____________________________________________________________________________________________________
fire7/squeeze1x1 (Conv2D)        (None, 13, 13, 48)    18480                                        
____________________________________________________________________________________________________
fire7/relu_squeeze1x1 (Activatio (None, 13, 13, 48)    0                                            
____________________________________________________________________________________________________
fire7/expand1x1 (Conv2D)         (None, 13, 13, 192)   9408                                         
____________________________________________________________________________________________________
fire7/expand3x3 (Conv2D)         (None, 13, 13, 192)   83136                                        
____________________________________________________________________________________________________
fire7/relu_expand1x1 (Activation (None, 13, 13, 192)   0                                            
____________________________________________________________________________________________________
fire7/relu_expand3x3 (Activation (None, 13, 13, 192)   0                                            
____________________________________________________________________________________________________
fire7/concat (Concatenate)       (None, 13, 13, 384)   0                                            
____________________________________________________________________________________________________
fire8/squeeze1x1 (Conv2D)        (None, 13, 13, 64)    24640                                        
____________________________________________________________________________________________________
fire8/relu_squeeze1x1 (Activatio (None, 13, 13, 64)    0                                            
____________________________________________________________________________________________________
fire8/expand1x1 (Conv2D)         (None, 13, 13, 256)   16640                                        
____________________________________________________________________________________________________
fire8/expand3x3 (Conv2D)         (None, 13, 13, 256)   147712                                       
____________________________________________________________________________________________________
fire8/relu_expand1x1 (Activation (None, 13, 13, 256)   0                                            
____________________________________________________________________________________________________
fire8/relu_expand3x3 (Activation (None, 13, 13, 256)   0                                            
____________________________________________________________________________________________________
fire8/concat (Concatenate)       (None, 13, 13, 512)   0                                            
____________________________________________________________________________________________________
fire9/squeeze1x1 (Conv2D)        (None, 13, 13, 64)    32832                                        
____________________________________________________________________________________________________
fire9/relu_squeeze1x1 (Activatio (None, 13, 13, 64)    0                                            
____________________________________________________________________________________________________
fire9/expand1x1 (Conv2D)         (None, 13, 13, 256)   16640                                        
____________________________________________________________________________________________________
fire9/expand3x3 (Conv2D)         (None, 13, 13, 256)   147712                                       
____________________________________________________________________________________________________
fire9/relu_expand1x1 (Activation (None, 13, 13, 256)   0                                            
____________________________________________________________________________________________________
fire9/relu_expand3x3 (Activation (None, 13, 13, 256)   0                                            
____________________________________________________________________________________________________
fire9/concat (Concatenate)       (None, 13, 13, 512)   0                                            
____________________________________________________________________________________________________
drop9 (Dropout)                  (None, 13, 13, 512)   0                                            
____________________________________________________________________________________________________
conv10 (Conv2D)                  (None, 13, 13, 1000)  513000                                       
____________________________________________________________________________________________________
relu_conv10 (Activation)         (None, 13, 13, 1000)  0                                            
____________________________________________________________________________________________________
global_average_pooling2d_8 (Glob (None, 1000)          0                                            
____________________________________________________________________________________________________
loss (Activation)                (None, 1000)          0                                            
====================================================================================================
Total params: 1,235,496.0
Trainable params: 1,235,496.0
Non-trainable params: 0.0
____________________________________________________________________________________________________

Modify the model to compute bottleneck features...

Leave out the final classification layers.

In [140]:
# We want to pull out the activations before conv10...

from keras.models import Model

# Get input
new_input = model.input
# Find the layer to connect
hidden_layer = model.get_layer('drop9').output
# Build a new model
bottleneck_model = Model(new_input, hidden_layer)
bottleneck_model.summary()
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
input_8 (InputLayer)             (None, 227, 227, 3)   0                                            
____________________________________________________________________________________________________
conv1 (Conv2D)                   (None, 113, 113, 64)  1792                                         
____________________________________________________________________________________________________
relu_conv1 (Activation)          (None, 113, 113, 64)  0                                            
____________________________________________________________________________________________________
pool1 (MaxPooling2D)             (None, 56, 56, 64)    0                                            
____________________________________________________________________________________________________
fire2/squeeze1x1 (Conv2D)        (None, 56, 56, 16)    1040                                         
____________________________________________________________________________________________________
fire2/relu_squeeze1x1 (Activatio (None, 56, 56, 16)    0                                            
____________________________________________________________________________________________________
fire2/expand1x1 (Conv2D)         (None, 56, 56, 64)    1088                                         
____________________________________________________________________________________________________
fire2/expand3x3 (Conv2D)         (None, 56, 56, 64)    9280                                         
____________________________________________________________________________________________________
fire2/relu_expand1x1 (Activation (None, 56, 56, 64)    0                                            
____________________________________________________________________________________________________
fire2/relu_expand3x3 (Activation (None, 56, 56, 64)    0                                            
____________________________________________________________________________________________________
fire2/concat (Concatenate)       (None, 56, 56, 128)   0                                            
____________________________________________________________________________________________________
fire3/squeeze1x1 (Conv2D)        (None, 56, 56, 16)    2064                                         
____________________________________________________________________________________________________
fire3/relu_squeeze1x1 (Activatio (None, 56, 56, 16)    0                                            
____________________________________________________________________________________________________
fire3/expand1x1 (Conv2D)         (None, 56, 56, 64)    1088                                         
____________________________________________________________________________________________________
fire3/expand3x3 (Conv2D)         (None, 56, 56, 64)    9280                                         
____________________________________________________________________________________________________
fire3/relu_expand1x1 (Activation (None, 56, 56, 64)    0                                            
____________________________________________________________________________________________________
fire3/relu_expand3x3 (Activation (None, 56, 56, 64)    0                                            
____________________________________________________________________________________________________
fire3/concat (Concatenate)       (None, 56, 56, 128)   0                                            
____________________________________________________________________________________________________
pool3 (MaxPooling2D)             (None, 27, 27, 128)   0                                            
____________________________________________________________________________________________________
fire4/squeeze1x1 (Conv2D)        (None, 27, 27, 32)    4128                                         
____________________________________________________________________________________________________
fire4/relu_squeeze1x1 (Activatio (None, 27, 27, 32)    0                                            
____________________________________________________________________________________________________
fire4/expand1x1 (Conv2D)         (None, 27, 27, 128)   4224                                         
____________________________________________________________________________________________________
fire4/expand3x3 (Conv2D)         (None, 27, 27, 128)   36992                                        
____________________________________________________________________________________________________
fire4/relu_expand1x1 (Activation (None, 27, 27, 128)   0                                            
____________________________________________________________________________________________________
fire4/relu_expand3x3 (Activation (None, 27, 27, 128)   0                                            
____________________________________________________________________________________________________
fire4/concat (Concatenate)       (None, 27, 27, 256)   0                                            
____________________________________________________________________________________________________
fire5/squeeze1x1 (Conv2D)        (None, 27, 27, 32)    8224                                         
____________________________________________________________________________________________________
fire5/relu_squeeze1x1 (Activatio (None, 27, 27, 32)    0                                            
____________________________________________________________________________________________________
fire5/expand1x1 (Conv2D)         (None, 27, 27, 128)   4224                                         
____________________________________________________________________________________________________
fire5/expand3x3 (Conv2D)         (None, 27, 27, 128)   36992                                        
____________________________________________________________________________________________________
fire5/relu_expand1x1 (Activation (None, 27, 27, 128)   0                                            
____________________________________________________________________________________________________
fire5/relu_expand3x3 (Activation (None, 27, 27, 128)   0                                            
____________________________________________________________________________________________________
fire5/concat (Concatenate)       (None, 27, 27, 256)   0                                            
____________________________________________________________________________________________________
pool5 (MaxPooling2D)             (None, 13, 13, 256)   0                                            
____________________________________________________________________________________________________
fire6/squeeze1x1 (Conv2D)        (None, 13, 13, 48)    12336                                        
____________________________________________________________________________________________________
fire6/relu_squeeze1x1 (Activatio (None, 13, 13, 48)    0                                            
____________________________________________________________________________________________________
fire6/expand1x1 (Conv2D)         (None, 13, 13, 192)   9408                                         
____________________________________________________________________________________________________
fire6/expand3x3 (Conv2D)         (None, 13, 13, 192)   83136                                        
____________________________________________________________________________________________________
fire6/relu_expand1x1 (Activation (None, 13, 13, 192)   0                                            
____________________________________________________________________________________________________
fire6/relu_expand3x3 (Activation (None, 13, 13, 192)   0                                            
____________________________________________________________________________________________________
fire6/concat (Concatenate)       (None, 13, 13, 384)   0                                            
____________________________________________________________________________________________________
fire7/squeeze1x1 (Conv2D)        (None, 13, 13, 48)    18480                                        
____________________________________________________________________________________________________
fire7/relu_squeeze1x1 (Activatio (None, 13, 13, 48)    0                                            
____________________________________________________________________________________________________
fire7/expand1x1 (Conv2D)         (None, 13, 13, 192)   9408                                         
____________________________________________________________________________________________________
fire7/expand3x3 (Conv2D)         (None, 13, 13, 192)   83136                                        
____________________________________________________________________________________________________
fire7/relu_expand1x1 (Activation (None, 13, 13, 192)   0                                            
____________________________________________________________________________________________________
fire7/relu_expand3x3 (Activation (None, 13, 13, 192)   0                                            
____________________________________________________________________________________________________
fire7/concat (Concatenate)       (None, 13, 13, 384)   0                                            
____________________________________________________________________________________________________
fire8/squeeze1x1 (Conv2D)        (None, 13, 13, 64)    24640                                        
____________________________________________________________________________________________________
fire8/relu_squeeze1x1 (Activatio (None, 13, 13, 64)    0                                            
____________________________________________________________________________________________________
fire8/expand1x1 (Conv2D)         (None, 13, 13, 256)   16640                                        
____________________________________________________________________________________________________
fire8/expand3x3 (Conv2D)         (None, 13, 13, 256)   147712                                       
____________________________________________________________________________________________________
fire8/relu_expand1x1 (Activation (None, 13, 13, 256)   0                                            
____________________________________________________________________________________________________
fire8/relu_expand3x3 (Activation (None, 13, 13, 256)   0                                            
____________________________________________________________________________________________________
fire8/concat (Concatenate)       (None, 13, 13, 512)   0                                            
____________________________________________________________________________________________________
fire9/squeeze1x1 (Conv2D)        (None, 13, 13, 64)    32832                                        
____________________________________________________________________________________________________
fire9/relu_squeeze1x1 (Activatio (None, 13, 13, 64)    0                                            
____________________________________________________________________________________________________
fire9/expand1x1 (Conv2D)         (None, 13, 13, 256)   16640                                        
____________________________________________________________________________________________________
fire9/expand3x3 (Conv2D)         (None, 13, 13, 256)   147712                                       
____________________________________________________________________________________________________
fire9/relu_expand1x1 (Activation (None, 13, 13, 256)   0                                            
____________________________________________________________________________________________________
fire9/relu_expand3x3 (Activation (None, 13, 13, 256)   0                                            
____________________________________________________________________________________________________
fire9/concat (Concatenate)       (None, 13, 13, 512)   0                                            
____________________________________________________________________________________________________
drop9 (Dropout)                  (None, 13, 13, 512)   0                                            
====================================================================================================
Total params: 722,496.0
Trainable params: 722,496.0
Non-trainable params: 0.0
____________________________________________________________________________________________________
In [144]:
16000 * 13 * 13 * 512 * 4 / 2**20
Out[144]:
5281

This will generate 5G of saved features! I guess I can save 20% by not pre-computing the test ones, but doesn't seem worth it. Instead, let's first train on a subset. We'll ignore the test set entirely, and see if we can get reasonable validation performance by using 2000 training images and 1000 validation.

In [146]:
train_subset = 2000
valid_subset = 1000
In [148]:
def save_bottlebeck_features(bottleneck_model, xs, name):
    # don't change the param!
    xs = preprocess_input(xs.copy())
    bottleneck_features = bottleneck_model.predict(xs)
    
    with open('cars_bottleneck_features_{}.npy'.format(name), 'w') as f:
        np.save(f, bottleneck_features)

def save_labels(ys, name):
    with open('cars_bottleneck_labels_{}.npy'.format(name), 'w') as f:
        np.save(f, ys)

        
if False: # change to True to recompute  
    save_bottlebeck_features(bottleneck_model, X_train[:train_subset], 'train_subset')
    save_bottlebeck_features(bottleneck_model, X_valid[:valid_subset], 'valid_subset')
    # save_bottlebeck_features(bottleneck_model, X_test, 'test')
    
    save_labels(Y_train[:train_subset], 'train_subset')
    save_labels(Y_valid[:valid_subset], 'valid_subset')
In [149]:
!ls -lh cars*
-rw-r--r--@ 1 shnayder  staff   660M Mar 21 22:30 cars_bottleneck_features_train_subset.npy
-rw-r--r--@ 1 shnayder  staff   330M Mar 21 22:31 cars_bottleneck_features_valid_subset.npy
-rw-r--r--@ 1 shnayder  staff   125K Mar 21 22:32 cars_bottleneck_labels_train_subset.npy
-rw-r--r--@ 1 shnayder  staff    63K Mar 21 22:32 cars_bottleneck_labels_valid_subset.npy
In [159]:
def load_features(name):
    with open('cars_bottleneck_features_{}.npy'.format(name), 'r') as f:
        return np.load(f)

def load_labels(name):
    with open('cars_bottleneck_labels_{}.npy'.format(name)) as f:
        return np.load(f)


top_model_weights_path = 'cars_bottleneck_fc_model.h5'    
    
# Now let's train the model -- we'll put the same squeezenet structure, just with fewer classes
def make_top_model():
    inputs = Input((13,13,512))
    x = Convolution2D(len(macro_classes), (1, 1), padding='valid', name='new_conv10')(inputs)
    x = Activation('relu', name='new_relu_conv10')(x)
    x = GlobalAveragePooling2D()(x)
    out = Activation('softmax', name='loss')(x)

    model = Model(inputs, out, name='squeezed_top')
    
    model.compile(optimizer='rmsprop',
                  loss='binary_crossentropy', metrics=['accuracy'])
    return model

top_model = make_top_model()
print(top_model.summary())
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_12 (InputLayer)        (None, 13, 13, 512)       0         
_________________________________________________________________
new_conv10 (Conv2D)          (None, 13, 13, 8)         4104      
_________________________________________________________________
new_relu_conv10 (Activation) (None, 13, 13, 8)         0         
_________________________________________________________________
global_average_pooling2d_12  (None, 8)                 0         
_________________________________________________________________
loss (Activation)            (None, 8)                 0         
=================================================================
Total params: 4,104.0
Trainable params: 4,104.0
Non-trainable params: 0.0
_________________________________________________________________
None
In [155]:
train_data = load_features('train_subset')
train_labels = load_labels('train_subset')

valid_data = load_features('valid_subset')
valid_labels = load_labels('valid_subset')
In [160]:
epochs = 50
batch_size = 128
history = top_model.fit(train_data, train_labels,
              epochs=epochs,
              batch_size=batch_size,
              validation_data=(valid_data, valid_labels))

top_model.save_weights(top_model_weights_path)
Train on 2000 samples, validate on 1000 samples
Epoch 1/50
2000/2000 [==============================] - 2s - loss: 0.6193 - acc: 0.8186 - val_loss: 0.5271 - val_acc: 0.8326
Epoch 2/50
2000/2000 [==============================] - 1s - loss: 0.4631 - acc: 0.8430 - val_loss: 0.4296 - val_acc: 0.8470
Epoch 3/50
2000/2000 [==============================] - 1s - loss: 0.3916 - acc: 0.8594 - val_loss: 0.3826 - val_acc: 0.8586
Epoch 4/50
2000/2000 [==============================] - 1s - loss: 0.3539 - acc: 0.8688 - val_loss: 0.3588 - val_acc: 0.8633
Epoch 5/50
2000/2000 [==============================] - 1s - loss: 0.3310 - acc: 0.8729 - val_loss: 0.3467 - val_acc: 0.8650
Epoch 6/50
2000/2000 [==============================] - 1s - loss: 0.3140 - acc: 0.8790 - val_loss: 0.3346 - val_acc: 0.8677
Epoch 7/50
2000/2000 [==============================] - 1s - loss: 0.3008 - acc: 0.8834 - val_loss: 0.3245 - val_acc: 0.8720
Epoch 8/50
2000/2000 [==============================] - 1s - loss: 0.2859 - acc: 0.8877 - val_loss: 0.3173 - val_acc: 0.8706
Epoch 9/50
2000/2000 [==============================] - 1s - loss: 0.2730 - acc: 0.8912 - val_loss: 0.3078 - val_acc: 0.8776
Epoch 10/50
2000/2000 [==============================] - 1s - loss: 0.2601 - acc: 0.8973 - val_loss: 0.3035 - val_acc: 0.8770
Epoch 11/50
2000/2000 [==============================] - 1s - loss: 0.2503 - acc: 0.9011 - val_loss: 0.2939 - val_acc: 0.8851
Epoch 12/50
2000/2000 [==============================] - 1s - loss: 0.2413 - acc: 0.9045 - val_loss: 0.2904 - val_acc: 0.8833
Epoch 13/50
2000/2000 [==============================] - 1s - loss: 0.2317 - acc: 0.9077 - val_loss: 0.2868 - val_acc: 0.8860
Epoch 14/50
2000/2000 [==============================] - 1s - loss: 0.2263 - acc: 0.9117 - val_loss: 0.2957 - val_acc: 0.8809
Epoch 15/50
2000/2000 [==============================] - 1s - loss: 0.2182 - acc: 0.9154 - val_loss: 0.2878 - val_acc: 0.8860
Epoch 16/50
2000/2000 [==============================] - 1s - loss: 0.2136 - acc: 0.9163 - val_loss: 0.2806 - val_acc: 0.8874
Epoch 17/50
2000/2000 [==============================] - 1s - loss: 0.2079 - acc: 0.9196 - val_loss: 0.2791 - val_acc: 0.8886
Epoch 18/50
2000/2000 [==============================] - 1s - loss: 0.2004 - acc: 0.9212 - val_loss: 0.2737 - val_acc: 0.8906
Epoch 19/50
2000/2000 [==============================] - 1s - loss: 0.1956 - acc: 0.9244 - val_loss: 0.2715 - val_acc: 0.8922
Epoch 20/50
2000/2000 [==============================] - 1s - loss: 0.1913 - acc: 0.9263 - val_loss: 0.2741 - val_acc: 0.8915
Epoch 21/50
2000/2000 [==============================] - 1s - loss: 0.1886 - acc: 0.9273 - val_loss: 0.2737 - val_acc: 0.8896
Epoch 22/50
2000/2000 [==============================] - 1s - loss: 0.1835 - acc: 0.9288 - val_loss: 0.2842 - val_acc: 0.8846
Epoch 23/50
2000/2000 [==============================] - 1s - loss: 0.1809 - acc: 0.9307 - val_loss: 0.2761 - val_acc: 0.8925
Epoch 24/50
2000/2000 [==============================] - 1s - loss: 0.1765 - acc: 0.9328 - val_loss: 0.2688 - val_acc: 0.8939
Epoch 25/50
2000/2000 [==============================] - 1s - loss: 0.1717 - acc: 0.9342 - val_loss: 0.2690 - val_acc: 0.8923
Epoch 26/50
2000/2000 [==============================] - 1s - loss: 0.1731 - acc: 0.9327 - val_loss: 0.2693 - val_acc: 0.8924
Epoch 27/50
2000/2000 [==============================] - 1s - loss: 0.1688 - acc: 0.9362 - val_loss: 0.2666 - val_acc: 0.8970
Epoch 28/50
2000/2000 [==============================] - 1s - loss: 0.1644 - acc: 0.9378 - val_loss: 0.2644 - val_acc: 0.8952
Epoch 29/50
2000/2000 [==============================] - 1s - loss: 0.1601 - acc: 0.9393 - val_loss: 0.2677 - val_acc: 0.8941
Epoch 30/50
2000/2000 [==============================] - 1s - loss: 0.1594 - acc: 0.9401 - val_loss: 0.2663 - val_acc: 0.8986
Epoch 31/50
2000/2000 [==============================] - 1s - loss: 0.1573 - acc: 0.9407 - val_loss: 0.2717 - val_acc: 0.8927
Epoch 32/50
2000/2000 [==============================] - 1s - loss: 0.1555 - acc: 0.9399 - val_loss: 0.2661 - val_acc: 0.8953
Epoch 33/50
2000/2000 [==============================] - 1s - loss: 0.1541 - acc: 0.9420 - val_loss: 0.2746 - val_acc: 0.8934
Epoch 34/50
2000/2000 [==============================] - 1s - loss: 0.1486 - acc: 0.9447 - val_loss: 0.2700 - val_acc: 0.8958
Epoch 35/50
2000/2000 [==============================] - 1s - loss: 0.1485 - acc: 0.9454 - val_loss: 0.2682 - val_acc: 0.8949
Epoch 36/50
2000/2000 [==============================] - 1s - loss: 0.1459 - acc: 0.9454 - val_loss: 0.2754 - val_acc: 0.8946
Epoch 37/50
2000/2000 [==============================] - 1s - loss: 0.1440 - acc: 0.9471 - val_loss: 0.2714 - val_acc: 0.8940
Epoch 38/50
2000/2000 [==============================] - 1s - loss: 0.1427 - acc: 0.9484 - val_loss: 0.2706 - val_acc: 0.8931
Epoch 39/50
2000/2000 [==============================] - 1s - loss: 0.1399 - acc: 0.9500 - val_loss: 0.2743 - val_acc: 0.8961
Epoch 40/50
2000/2000 [==============================] - 1s - loss: 0.1372 - acc: 0.9486 - val_loss: 0.2662 - val_acc: 0.8976
Epoch 41/50
2000/2000 [==============================] - 1s - loss: 0.1378 - acc: 0.9503 - val_loss: 0.2740 - val_acc: 0.8957
Epoch 42/50
2000/2000 [==============================] - 1s - loss: 0.1338 - acc: 0.9511 - val_loss: 0.2718 - val_acc: 0.8951
Epoch 43/50
2000/2000 [==============================] - 1s - loss: 0.1325 - acc: 0.9507 - val_loss: 0.2748 - val_acc: 0.8953
Epoch 44/50
2000/2000 [==============================] - 1s - loss: 0.1309 - acc: 0.9523 - val_loss: 0.2748 - val_acc: 0.8945
Epoch 45/50
2000/2000 [==============================] - 1s - loss: 0.1289 - acc: 0.9531 - val_loss: 0.2757 - val_acc: 0.8931
Epoch 46/50
2000/2000 [==============================] - 1s - loss: 0.1274 - acc: 0.9536 - val_loss: 0.2720 - val_acc: 0.8916
Epoch 47/50
2000/2000 [==============================] - 1s - loss: 0.1270 - acc: 0.9554 - val_loss: 0.2906 - val_acc: 0.8897
Epoch 48/50
2000/2000 [==============================] - 1s - loss: 0.1257 - acc: 0.9552 - val_loss: 0.2793 - val_acc: 0.8948
Epoch 49/50
2000/2000 [==============================] - 1s - loss: 0.1218 - acc: 0.9572 - val_loss: 0.2824 - val_acc: 0.8935
Epoch 50/50
2000/2000 [==============================] - 1s - loss: 0.1240 - acc: 0.9537 - val_loss: 0.2772 - val_acc: 0.8944
In [162]:
plot_training_curves(history.history);

Almost 90% validation accuracy! Clearly we could have stopped earlier. Let's take a quick look at the confusion matrix.

In [163]:
predict_train = top_model.predict(train_data)
In [166]:
plot_confusion_matrix(train_labels, predict_train, macro_classes,                      
                      title="Train confusion matrix");
plt.figure()
plot_confusion_matrix(train_labels, predict_train, macro_classes,                      
                      title="Train confusion matrix",
                     normalize=True);
Confusion matrix, without normalization
Normalized confusion matrix

So we're making relatively few mistakes with pickups and vans and sedans, somewhat more with SUVs, and confusing wagons, convertibles, and coupes for sedans. Makes sense. Perhaps we should combine all those classes together anyway.

We could fine-tune the network and do data augmentation too. For now, let's just train on the rest of our training data.

In [167]:
def compute_bottleneck_features(xs):
    xs = preprocess_input(xs.copy())
    return bottleneck_model.predict(xs)

rest_train_data = compute_bottleneck_features(X_train[train_subset:])
rest_train_labels = Y_train[train_subset:]
In [168]:
epochs = 50
batch_size = 128
history2 = top_model.fit(rest_train_data, rest_train_labels,
               epochs=epochs,
               batch_size=batch_size,
               validation_data=(valid_data, valid_labels))
Train on 7867 samples, validate on 1000 samples
Epoch 1/50
7867/7867 [==============================] - 6s - loss: 0.2646 - acc: 0.8978 - val_loss: 0.2447 - val_acc: 0.9044
Epoch 2/50
7867/7867 [==============================] - 5s - loss: 0.2384 - acc: 0.9052 - val_loss: 0.2610 - val_acc: 0.8935
Epoch 3/50
7867/7867 [==============================] - 5s - loss: 0.2276 - acc: 0.9089 - val_loss: 0.2278 - val_acc: 0.9056
Epoch 4/50
7867/7867 [==============================] - 5s - loss: 0.2182 - acc: 0.9127 - val_loss: 0.2317 - val_acc: 0.9039
Epoch 5/50
7867/7867 [==============================] - 5s - loss: 0.2114 - acc: 0.9147 - val_loss: 0.2274 - val_acc: 0.9067
Epoch 6/50
7867/7867 [==============================] - 5s - loss: 0.2061 - acc: 0.9168 - val_loss: 0.2228 - val_acc: 0.9067
Epoch 7/50
7867/7867 [==============================] - 5s - loss: 0.2004 - acc: 0.9180 - val_loss: 0.2291 - val_acc: 0.9046
Epoch 8/50
7867/7867 [==============================] - 5s - loss: 0.1959 - acc: 0.9205 - val_loss: 0.2162 - val_acc: 0.9089
Epoch 9/50
7867/7867 [==============================] - 5s - loss: 0.1929 - acc: 0.9217 - val_loss: 0.2090 - val_acc: 0.9131
Epoch 10/50
7867/7867 [==============================] - 5s - loss: 0.1893 - acc: 0.9238 - val_loss: 0.2074 - val_acc: 0.9129
Epoch 11/50
7867/7867 [==============================] - 5s - loss: 0.1847 - acc: 0.9258 - val_loss: 0.2123 - val_acc: 0.9113
Epoch 12/50
7867/7867 [==============================] - 5s - loss: 0.1831 - acc: 0.9251 - val_loss: 0.2040 - val_acc: 0.9151
Epoch 13/50
7867/7867 [==============================] - 5s - loss: 0.1797 - acc: 0.9276 - val_loss: 0.2123 - val_acc: 0.9103
Epoch 14/50
7867/7867 [==============================] - 5s - loss: 0.1777 - acc: 0.9281 - val_loss: 0.2091 - val_acc: 0.9146
Epoch 15/50
7867/7867 [==============================] - 5s - loss: 0.1752 - acc: 0.9290 - val_loss: 0.2073 - val_acc: 0.9141
Epoch 16/50
7867/7867 [==============================] - 5s - loss: 0.1744 - acc: 0.9287 - val_loss: 0.2069 - val_acc: 0.9141
Epoch 17/50
7867/7867 [==============================] - 5s - loss: 0.1720 - acc: 0.9305 - val_loss: 0.2157 - val_acc: 0.9100
Epoch 18/50
7867/7867 [==============================] - 5s - loss: 0.1704 - acc: 0.9307 - val_loss: 0.2038 - val_acc: 0.9155
Epoch 19/50
7867/7867 [==============================] - 5s - loss: 0.1687 - acc: 0.9312 - val_loss: 0.2051 - val_acc: 0.9157
Epoch 20/50
7867/7867 [==============================] - 5s - loss: 0.1663 - acc: 0.9326 - val_loss: 0.2085 - val_acc: 0.9136
Epoch 21/50
7867/7867 [==============================] - 5s - loss: 0.1641 - acc: 0.9335 - val_loss: 0.2030 - val_acc: 0.9178
Epoch 22/50
7867/7867 [==============================] - 5s - loss: 0.1625 - acc: 0.9347 - val_loss: 0.2062 - val_acc: 0.9151
Epoch 23/50
7867/7867 [==============================] - 5s - loss: 0.1630 - acc: 0.9343 - val_loss: 0.2046 - val_acc: 0.9154
Epoch 24/50
7867/7867 [==============================] - 5s - loss: 0.1606 - acc: 0.9351 - val_loss: 0.2185 - val_acc: 0.9099
Epoch 25/50
7867/7867 [==============================] - 5s - loss: 0.1595 - acc: 0.9356 - val_loss: 0.2082 - val_acc: 0.9157
Epoch 26/50
7867/7867 [==============================] - 5s - loss: 0.1594 - acc: 0.9357 - val_loss: 0.2245 - val_acc: 0.9074
Epoch 27/50
7867/7867 [==============================] - 5s - loss: 0.1565 - acc: 0.9365 - val_loss: 0.2056 - val_acc: 0.9156
Epoch 28/50
7867/7867 [==============================] - 5s - loss: 0.1572 - acc: 0.9368 - val_loss: 0.2059 - val_acc: 0.9175
Epoch 29/50
7867/7867 [==============================] - 5s - loss: 0.1548 - acc: 0.9378 - val_loss: 0.2174 - val_acc: 0.9101
Epoch 30/50
7867/7867 [==============================] - 5s - loss: 0.1545 - acc: 0.9379 - val_loss: 0.2179 - val_acc: 0.9116
Epoch 31/50
7867/7867 [==============================] - 5s - loss: 0.1550 - acc: 0.9372 - val_loss: 0.2136 - val_acc: 0.9126
Epoch 32/50
7867/7867 [==============================] - 5s - loss: 0.1537 - acc: 0.9375 - val_loss: 0.2172 - val_acc: 0.9100
Epoch 33/50
7867/7867 [==============================] - 4s - loss: 0.1519 - acc: 0.9390 - val_loss: 0.2060 - val_acc: 0.9159
Epoch 34/50
7867/7867 [==============================] - 5s - loss: 0.1523 - acc: 0.9386 - val_loss: 0.2093 - val_acc: 0.9160
Epoch 35/50
7867/7867 [==============================] - 5s - loss: 0.1500 - acc: 0.9402 - val_loss: 0.2085 - val_acc: 0.9153
Epoch 36/50
7867/7867 [==============================] - 5s - loss: 0.1511 - acc: 0.9394 - val_loss: 0.2169 - val_acc: 0.9142
Epoch 37/50
7867/7867 [==============================] - 5s - loss: 0.1495 - acc: 0.9399 - val_loss: 0.2074 - val_acc: 0.9134
Epoch 38/50
7867/7867 [==============================] - 5s - loss: 0.1488 - acc: 0.9404 - val_loss: 0.2231 - val_acc: 0.9129
Epoch 39/50
7867/7867 [==============================] - 5s - loss: 0.1488 - acc: 0.9404 - val_loss: 0.2103 - val_acc: 0.9146
Epoch 40/50
7867/7867 [==============================] - 5s - loss: 0.1473 - acc: 0.9408 - val_loss: 0.2223 - val_acc: 0.9131
Epoch 41/50
7867/7867 [==============================] - 5s - loss: 0.1465 - acc: 0.9414 - val_loss: 0.2160 - val_acc: 0.9140
Epoch 42/50
7867/7867 [==============================] - 5s - loss: 0.1478 - acc: 0.9414 - val_loss: 0.2162 - val_acc: 0.9146
Epoch 43/50
7867/7867 [==============================] - 5s - loss: 0.1468 - acc: 0.9420 - val_loss: 0.2102 - val_acc: 0.9146
Epoch 44/50
7867/7867 [==============================] - 5s - loss: 0.1457 - acc: 0.9425 - val_loss: 0.2150 - val_acc: 0.9159
Epoch 45/50
7867/7867 [==============================] - 5s - loss: 0.1448 - acc: 0.9421 - val_loss: 0.2261 - val_acc: 0.9125
Epoch 46/50
7867/7867 [==============================] - 4s - loss: 0.1449 - acc: 0.9422 - val_loss: 0.2164 - val_acc: 0.9151
Epoch 47/50
7867/7867 [==============================] - 5s - loss: 0.1449 - acc: 0.9424 - val_loss: 0.2057 - val_acc: 0.9156
Epoch 48/50
7867/7867 [==============================] - 5s - loss: 0.1442 - acc: 0.9423 - val_loss: 0.2099 - val_acc: 0.9149
Epoch 49/50
7867/7867 [==============================] - 5s - loss: 0.1441 - acc: 0.9418 - val_loss: 0.2167 - val_acc: 0.9144
Epoch 50/50
7867/7867 [==============================] - 5s - loss: 0.1431 - acc: 0.9429 - val_loss: 0.2324 - val_acc: 0.9045
In [169]:
from helpers import combine_histories
plot_training_curves(combine_histories(history.history, history2.history));

Ok, that got us about 2 more percent. Let's save these weights.

In [170]:
top_model.save_weights(top_model_weights_path)

Look at validation confusion matrix

So, what's the confusion matrix for our validation data like?

In [173]:
predict_valid = top_model.predict(valid_data)
In [175]:
top_model.evaluate(valid_data, valid_labels)
 992/1000 [============================>.] - ETA: 0s
Out[175]:
[0.2324370242357254, 0.90449999999999997]
In [174]:
plot_confusion_matrix(valid_labels, predict_valid, macro_classes,                      
                      title="Validation confusion matrix");
plt.figure()
plot_confusion_matrix(valid_labels, predict_valid, macro_classes,                      
                      title="Validation confusion matrix",
                     normalize=True);
Confusion matrix, without normalization
Normalized confusion matrix

Next steps...

We'll stop for now. If we wanted to continue, here are some things to do:

  • Continue to improve the classifier by using data augmentation
  • Fine-tune several layers of squeezenet
  • Go back to our raw data and use the bounding box info, or resize/crop differently.
  • Add more than one new layer on top of squeezenet.
  • Try a different pre-trained network.
  • To test out our final output, combine the really confusing classes as discussed above, and make a handy function to take an image and run it through the combined network to give a classification...